import com.alibaba.fastjson.JSONObject;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.*;
import org.apache.spark.util.LongAccumulator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.zjvis.datascience.common.dto.TaskInstanceDTO;
import org.zjvis.datascience.common.exception.BaseErrorCode;
import org.zjvis.datascience.common.exception.DataScienceException;
import org.zjvis.datascience.common.graph.model.TupleWrapper;
import org.zjvis.datascience.common.vo.graph.*;
import org.zjvis.datascience.spark.util.UtilTool;
import scala.Tuple2;
import scala.collection.JavaConversions;

import java.io.Serializable;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.*;
import java.util.stream.Collectors;

/**
 * @description Graph construction test
 * @date 2021-12-23
 */
public class GraphBuild implements Serializable {
    // Class-wide SLF4J logger, registered under the name "GraphBuild".
    private final static Logger logger = LoggerFactory.getLogger("GraphBuild");
    // Spark session supplied by the caller; used to create the accumulator and to read category tables.
    private SparkSession sparkSession;

    // NOTE(review): never assigned or read in the visible code — confirm whether it is still needed.
    private TaskInstanceDTO instance;
    // Raw graph definition as a JSON string; parsed in initHelper().
    private String dataJson;

    // Category definitions read from the "categories" array of dataJson.
    private List<CategoryVO> categories;
    // Edge definitions read from the "edges" array of dataJson.
    private List<EdgeVO> edges;


    // category id -> category definition
    private Map<String, CategoryVO> cIdMap;
    // edge id -> edge definition
    private Map<String, EdgeVO> eIdMap;
    // category id -> (key attribute value -> node); filled in loadGraphData() and read when links are created
    private Map<String, Map<Object, NodeVO>> cId2nIdMap;
    // edge id -> links by id; only initialised with empty maps in the visible code
    private Map<String, Map<String, LinkVO>> eId2lIdMap;
    // Long accumulator created in initHelper(); no visible code ever adds to it.
    private LongAccumulator lIdCnt;

    // Graph identifier; embedded in generated node and link ids.
    private Long graphId;
    // Created in initHelper(); not otherwise used in the visible code.
    private GraphOutputVO outputHelper;

    // category id or edge id -> (attribute name -> true when some element holds more than one value for it)
    private Map<String, Map<String, Boolean>> multiValueHelper;
    // attribute id -> (owning category id, attribute definition)
    private Map<String, TupleWrapper<String, CategoryAttrVO>> attrIdHelper;
    // category id -> dataset loaded for that category's table
    private Map<String, Dataset<Row>> datasetMap;

    /**
     * Creates a builder for one graph and immediately prepares its lookup structures.
     *
     * @param sparkSession Spark session used for all dataset work
     * @param graphId      identifier of the graph being built
     * @param dataJson     graph definition as a JSON string
     */
    public GraphBuild(SparkSession sparkSession, Long graphId, String dataJson) {
        this.dataJson = dataJson;
        this.graphId = graphId;
        this.sparkSession = sparkSession;
        // The fields above must be set first: initHelper() reads dataJson and sparkSession.
        this.initHelper();
    }

    /**
     * Loads the {@code data_json} column of the graph row with the given id.
     *
     * <p>This is a best-effort read: a SQL failure is logged and the value read so far
     * (initially the empty string) is returned instead of propagating the exception.
     *
     * @param conn    open JDBC connection; when {@code null} nothing is queried
     * @param graphId id of the row in {@code aiworks.graph}
     * @return the stored JSON, or {@code ""} when the connection is null, no row matches,
     *         or the query fails before a row is read
     */
    public static String getTaskInput(Connection conn, Long graphId) {
        String dataJson = "";
        if (conn == null) {
            return dataJson;
        }
        String sql = String.format("select data_json from aiworks.graph where id = %s", graphId);
        // try-with-resources closes the ResultSet as well as the Statement, in reverse order.
        try (Statement statement = conn.createStatement();
             ResultSet rs = statement.executeQuery(sql)) {
            while (rs.next()) {
                dataJson = rs.getString(1);
            }
        } catch (SQLException e) {
            logger.error("failed to load data_json for graph id {}", graphId, e);
        }
        return dataJson;
    }

    /**
     * Parses {@code dataJson} and (re)builds every lookup structure used while loading the graph:
     * id maps for categories and edges, per-element multi-value flags (all initially false),
     * the attribute-id index, the dataset cache and the accumulator.
     *
     * @throws DataScienceException with {@code GRAPH_BUILD_ERROR} when the JSON is empty, lacks the
     *         "categories" or "edges" array, or an edge references an attribute id that no category declares
     */
    public void initHelper() {
        cIdMap = new HashMap<>();
        eIdMap = new HashMap<>();
        cId2nIdMap = new HashMap<>();
        eId2lIdMap = new HashMap<>();
        multiValueHelper = new HashMap<>();
        attrIdHelper = new HashMap<>();

        JSONObject dataJsonObj = JSONObject.parseObject(dataJson);
        // Fail with a readable message instead of a bare NullPointerException further down.
        if (dataJsonObj == null
                || dataJsonObj.getJSONArray("categories") == null
                || dataJsonObj.getJSONArray("edges") == null) {
            throw DataScienceException.of(BaseErrorCode.GRAPH_BUILD_ERROR,
                    "graph dataJson is empty or lacks categories/edges, graphId=" + graphId);
        }
        categories = dataJsonObj.getJSONArray("categories").toJavaList(CategoryVO.class);
        edges = dataJsonObj.getJSONArray("edges").toJavaList(EdgeVO.class);

        for (CategoryVO category : categories) {
            String cId = category.getId();
            Map<String, Boolean> multiFlags = new HashMap<>();
            cIdMap.put(cId, category);
            cId2nIdMap.put(cId, new HashMap<>());
            multiValueHelper.put(cId, multiFlags);
            attrIdHelper.put(category.getKeyAttr().getId(), new TupleWrapper<>(cId, category.getKeyAttr()));
            // A category may carry no secondary attributes at all.
            if (category.getAttrs() == null) {
                continue;
            }
            for (CategoryAttrVO categoryAttr : category.getAttrs()) {
                multiFlags.put(categoryAttr.getName(), false);
                attrIdHelper.put(categoryAttr.getId(), new TupleWrapper<>(cId, categoryAttr));
            }
        }

        for (EdgeVO edge : edges) {
            String eId = edge.getId();
            Map<String, Boolean> multiFlags = new HashMap<>();
            eIdMap.put(eId, edge);
            eId2lIdMap.put(eId, new HashMap<>());
            multiValueHelper.put(eId, multiFlags);
            if (edge.getAttrIds() == null) {
                continue;
            }
            for (String attrId : edge.getAttrIds()) {
                TupleWrapper<String, CategoryAttrVO> owner = attrIdHelper.get(attrId);
                if (owner == null) {
                    throw DataScienceException.of(BaseErrorCode.GRAPH_BUILD_ERROR,
                            String.format("edge %s references unknown attribute id %s", eId, attrId));
                }
                CategoryAttrVO categoryAttr = owner.getRight();
                multiFlags.put(categoryAttr.getName(), false);
                // Resolve the referenced attribute definitions onto the edge itself.
                if (edge.getAttrs() == null) {
                    edge.setAttrs(new ArrayList<>());
                }
                edge.getAttrs().add(categoryAttr);
            }
        }

        outputHelper = new GraphOutputVO();
        datasetMap = new HashMap<>();
        lIdCnt = sparkSession.sparkContext().longAccumulator();
    }

    /**
     * Reads every category table, materialises its nodes, then materialises the links of every edge.
     *
     * <p>For each category the nodes are collected exactly once; the per-attribute multi-value flags,
     * the key-value to node index and the running row offset are all derived from that single result,
     * so they always describe the same set of node ids. Links are handled the same way.
     *
     * @return always {@code true}
     * @throws DataScienceException with {@code GRAPH_BUILD_ERROR} when an edge references a category
     *         id that is not defined
     */
    public boolean loadGraphData() {
        long rowNumIncrease = 0L;
        List<NodeVO> nodes = new ArrayList<>();
        List<LinkVO> links = new ArrayList<>();
        for (CategoryVO category : categories) {
            String cId = category.getId();
            datasetMap.put(cId, UtilTool.readFromGreenPlum(sparkSession, category.tableName(), UtilTool.DEFAULT_ID_COL));
            List<NodeVO> categoryNodes = buildCategory(category, rowNumIncrease).collect();

            Map<String, Boolean> multiFlags = new HashMap<>();
            // Index only THIS category's nodes; key values may repeat across categories.
            Map<Object, NodeVO> keyToNode = new HashMap<>();
            for (NodeVO node : categoryNodes) {
                for (AttrVO attr : node.getAttrs()) {
                    multiFlags.merge(attr.getKey(), isMultiValued(attr.getValue()), Boolean::logicalOr);
                }
                // The first attribute of a node is its key attribute.
                keyToNode.put(node.getAttrs().get(0).getValue(), node);
            }
            multiValueHelper.get(cId).putAll(multiFlags);
            cId2nIdMap.put(cId, keyToNode);
            nodes.addAll(categoryNodes);
            rowNumIncrease += categoryNodes.size();
        }

        // Link row numbers restart from zero.
        rowNumIncrease = 0L;
        for (EdgeVO edge : eIdMap.values()) {
            String srcCid = edge.getSource();
            String tarCid = edge.getTarget();
            CategoryVO srcCategory = cIdMap.get(srcCid);
            CategoryVO tarCategory = cIdMap.get(tarCid);
            if (srcCategory == null || tarCategory == null) {
                throw DataScienceException.of(BaseErrorCode.GRAPH_BUILD_ERROR,
                        String.format("edge %s references unknown category, source=%s, target=%s",
                                edge.getId(), srcCid, tarCid));
            }
            JavaRDD<LinkVO> linkRdd = srcCid.equals(tarCid)
                    // self-referencing edge
                    ? buildPairFromSelfEdge(srcCategory, edge, rowNumIncrease)
                    : buildPairFromJoinTable(srcCategory, tarCategory, edge, rowNumIncrease);
            List<LinkVO> edgeLinks = linkRdd.collect();

            Map<String, Boolean> multiFlags = new HashMap<>();
            for (LinkVO link : edgeLinks) {
                for (AttrVO attr : link.getAttrs()) {
                    multiFlags.merge(attr.getKey(), isMultiValued(attr.getValue()), Boolean::logicalOr);
                }
            }
            multiValueHelper.get(edge.getId()).putAll(multiFlags);
            rowNumIncrease += edgeLinks.size();
            links.addAll(edgeLinks);
        }
        logger.info("graph {} loaded: {} nodes, {} links", graphId, nodes.size(), links.size());
        return true;
    }

    /** Tells whether an attribute value is a collection holding more than one element. */
    private static boolean isMultiValued(Object value) {
        return value instanceof Collection && ((Collection<?>) value).size() > 1;
    }

    /**
     * Builds the node RDD of one category from its cached dataset.
     *
     * <p>Rows are de-duplicated on the key attribute. When the category declares secondary
     * attributes, each one is aggregated per key with {@code collect_set}. Every resulting row
     * receives a {@code rowNum} column: a row number shifted by {@code rowNumIncrease}.
     *
     * @param category       category whose dataset was previously stored in {@code datasetMap}
     * @param rowNumIncrease offset added to the generated row numbers
     * @return one {@link NodeVO} per distinct key attribute value
     */
    public JavaRDD<NodeVO> buildCategory(CategoryVO category, Long rowNumIncrease) {
        Dataset<Row> dataset = datasetMap.get(category.getId());
        String keyAttrColumnName = category.getKeyAttr().getColumn();
        String labelColumnName = category.getOriginLabel().getColumn();
        Column keyAttrCol = new Column(keyAttrColumnName).as("keyAttr");
        Column labelCol = new Column(labelColumnName).as("labelAttr");
        Column rowNumCol = functions.expr(String.format("ROW_NUMBER() OVER (ORDER BY 1) + %s", rowNumIncrease));
        Dataset<Row> result;
        // The null test has to come first, otherwise it can never protect the emptiness test.
        if (category.getAttrs() == null || category.getAttrs().isEmpty()) {
            result = dataset.select(keyAttrCol, labelCol).dropDuplicates("keyAttr")
                    .withColumn("rowNum", rowNumCol);
        } else {
            List<Column> attrCollects = category.getAttrs().stream()
                    .map(CategoryAttrVO::getColumn)
                    .map(colName -> functions.collect_set(colName).as(colName))
                    .collect(Collectors.toList());

            List<Column> attrColumnNames = category.getAttrs().stream()
                    .map(CategoryAttrVO::getColumn)
                    .map(Column::new)
                    .collect(Collectors.toList());

            Dataset<Row> label = dataset.select(keyAttrCol, labelCol).dropDuplicates("keyAttr").as("label");
            // agg() takes the first column separately from the remaining ones.
            Dataset<Row> attr = attrCollects.size() > 1 ?
                    dataset.groupBy(keyAttrCol).agg(attrCollects.get(0), JavaConversions.asScalaBuffer(attrCollects.subList(1, attrCollects.size()))).as("attr")
                    : dataset.groupBy(keyAttrCol).agg(attrCollects.get(0)).as("attr");
            // Final column order: keyAttr, labelAttr, then the aggregated attribute columns.
            attrColumnNames.add(0, new Column("label.labelAttr"));
            attrColumnNames.add(0, new Column("label.keyAttr"));
            result = label.join(attr, new Column("label.keyAttr").equalTo(new Column("attr.keyAttr")))
                    .select(JavaConversions.asScalaBuffer(attrColumnNames))
                    .withColumn("rowNum", rowNumCol);
        }
        return result.toJavaRDD().map(new createNode(category));
    }


    /**
     * Builds the link RDD for an edge that connects two different categories.
     *
     * @param srcCategory    category on the source side of the edge
     * @param tarCategory    category on the target side of the edge
     * @param edge           edge definition; its join configuration must provide
     *                       {@code leftHeaderName} and {@code rightHeaderName}
     * @param rowNumIncrease offset added to the generated row numbers
     * @return one {@link LinkVO} per distinct (source key, target key) pair
     * @throws DataScienceException with {@code GRAPH_BUILD_ERROR} when the join configuration
     *         is missing or incomplete
     */
    public JavaRDD<LinkVO> buildPairFromJoinTable(CategoryVO srcCategory, CategoryVO tarCategory, EdgeVO edge, Long rowNumIncrease) {
        JSONObject conf = edge.getJoinConfigure();
        if (conf == null) {
            throw DataScienceException.of(BaseErrorCode.GRAPH_BUILD_ERROR, "edge joinConfigure error, conf=null");
        }
        String leftHeaderName = conf.getString("leftHeaderName");
        String rightHeaderName = conf.getString("rightHeaderName");
        if (leftHeaderName == null || rightHeaderName == null) {
            // Report the actual values so the broken side of the configuration is visible.
            throw DataScienceException.of(BaseErrorCode.GRAPH_BUILD_ERROR,
                    String.format("edge joinConfigure error, leftHeaderName = %s,  rightHeaderName = %s",
                            leftHeaderName, rightHeaderName));
        }
        Dataset<Row> result = getJoinRelation(srcCategory.getId(), tarCategory.getId(),
                leftHeaderName, rightHeaderName,
                srcCategory.getKeyAttr().getColumn(), tarCategory.getKeyAttr().getColumn(),
                rowNumIncrease);
        return result.toJavaRDD().map(new createLink(edge));
    }

    /**
     * Builds the link RDD for an edge whose source and target are the same category.
     *
     * @param category       category on both sides of the edge
     * @param edge           edge definition; its join configuration must provide
     *                       {@code leftHeaderName} and {@code rightHeaderName}
     * @param rowNumIncrease offset added to the generated row numbers
     * @return one {@link LinkVO} per distinct (source key, target key) pair
     * @throws DataScienceException with {@code GRAPH_BUILD_ERROR} when the join configuration
     *         is missing or incomplete
     */
    public JavaRDD<LinkVO> buildPairFromSelfEdge(CategoryVO category, EdgeVO edge, Long rowNumIncrease) {
        JSONObject conf = edge.getJoinConfigure();
        if (conf == null) {
            throw DataScienceException.of(BaseErrorCode.GRAPH_BUILD_ERROR, "edge joinConfigure error, conf=null");
        }
        String leftHeaderName = conf.getString("leftHeaderName");
        String rightHeaderName = conf.getString("rightHeaderName");
        if (leftHeaderName == null || rightHeaderName == null) {
            // Report the actual values so the broken side of the configuration is visible.
            throw DataScienceException.of(BaseErrorCode.GRAPH_BUILD_ERROR,
                    String.format("edge joinConfigure error, leftHeaderName = %s,  rightHeaderName = %s",
                            leftHeaderName, rightHeaderName));
        }
        // The same dataset serves as both sides of the join.
        String keyColumn = category.getKeyAttr().getColumn();
        Dataset<Row> result = getJoinRelation(category.getId(), category.getId(),
                leftHeaderName, rightHeaderName,
                keyColumn, keyColumn,
                rowNumIncrease);
        return result.toJavaRDD().map(new createLink(edge));
    }

    class createNode implements Function<Row, NodeVO> {
        private CategoryVO category;

        public createNode(CategoryVO category) {
            this.category = category;
        }

        public NodeVO call(Row row) {
            Object keyAttrValue = row.get(row.fieldIndex("keyAttr"));
            Integer rowNum = row.getInt(row.fieldIndex("rowNum"));
            String nId = String.format("n_%s_%s", graphId, rowNum);
            //主属性
            List<AttrVO> attrs = new ArrayList<>();
            attrs.add(AttrVO.builder().
                    key(category.getKeyAttr().getName()).
                    value(keyAttrValue).
                    type(category.getKeyAttr().getType()).
                    build());
            String cId = category.getId();
            //副属性
            for (CategoryAttrVO categoryAttr : category.getAttrs()) {
                String name = categoryAttr.getName();
                List<?> val = row.getList(row.fieldIndex(name));
                attrs.add(AttrVO.builder().
                        key(categoryAttr.getName()).
                        value(val).
                        type(categoryAttr.getType()).
                        build());
            }
            NodeVO node = new NodeVO();
            node.setId(nId);
            node.setOrderId(rowNum);
            node.setLabel(row.get(row.fieldIndex("labelAttr")).toString());
            node.setCategoryId(cId);
            node.setAttrs(attrs);
            node.getStyle().put("fill", category.getStyle().getString("fill"));
            return node;
        }
    }

    class createLink implements Function<Row, LinkVO> {
        private EdgeVO edge;

        public createLink(EdgeVO edge) {
            this.edge = edge;
        }

        public LinkVO call(Row row) {
            Object srcValue = row.get(row.fieldIndex("src"));
            Object tarValue = row.get(row.fieldIndex("tar"));
            Integer rowNum = row.getInt(row.fieldIndex("rowNum"));
            NodeVO src = cId2nIdMap.get(edge.getSource()).get(srcValue);
            NodeVO tar = cId2nIdMap.get(edge.getTarget()).get(tarValue);
            String srcNid = src.getId();
            String tarNid = tar.getId();
            String eId = edge.getId();
            String lId = String.format("l_%s_%s", graphId, rowNum);
            LinkVO link = new LinkVO();
            link.setId(lId);
            link.setOrderId(lIdCnt.value().intValue());
            link.setSource(srcNid);
            link.setTarget(tarNid);
            link.setEdgeId(eId);
            link.setLabel(eIdMap.get(eId).getLabel());
            link.setDirected(edge.getStyle() == null || edge.getStyle().getBoolean("endArrow"));
            //set weight
            JSONObject weightCfg = edge.getWeightCfg();
            if (weightCfg != null && !weightCfg.isEmpty() && weightCfg.get("attrId") != null) {
                String attrId = weightCfg.getString("attrId");
                String cid = attrIdHelper.get(attrId).getLeft();
                CategoryAttrVO categoryAttr = attrIdHelper.get(attrId).getRight();
                NodeVO weightNode = cid.equals(edge.getSource()) ? src : tar;
                Optional<AttrVO> opt = weightNode.getAttrs().stream().filter(attr -> categoryAttr.getColumn().equals(attr.getKey())).findFirst();
                if (opt.isPresent()) {
                    Object weight = opt.get().getValue();
                    try {
                        link.setWeight(Double.parseDouble(weight.toString()));
                    } catch (NumberFormatException e) {
                        link.setWeight(1.0);
                    }
                }
            } else {
                link.setWeight(1.0);
            }

            //边属性
            List<AttrVO> attrs = new ArrayList<>();
            if (edge.getAttrIds() != null) {
                for (String attrId : edge.getAttrIds()) {
                    String cid = attrIdHelper.get(attrId).getLeft();
                    CategoryAttrVO categoryAttr = attrIdHelper.get(attrId).getRight();
                    NodeVO weightNode = cid.equals(edge.getSource()) ? src : tar;
                    Optional<AttrVO> opt = weightNode.getAttrs().stream().filter(attr -> categoryAttr.getColumn().equals(attr.getKey())).findFirst();
                    if (opt.isPresent()) {
                        Object val = opt.get().getValue();
                        if (val instanceof Collection) {
                            if (!multiValueHelper.get(eId).get(categoryAttr.getName()) && ((Collection) val).size() > 1) {
                                multiValueHelper.get(eId).put(categoryAttr.getName(), true);
                            }
                        }
                        attrs.add(AttrVO.builder().
                                key(categoryAttr.getName()).
                                value(val).
                                type(categoryAttr.getType()).
                                build());
                    }
                }
            }
            link.setAttrs(attrs);
            link.getStyle().put("stroke", edge.getStyle().getString("stroke"));
            return link;
        }

    }

    /**
     * Joins two cached category datasets on a pair of columns and returns the distinct
     * (source attribute, target attribute) combinations.
     *
     * @param srcId           key of the left dataset in {@code datasetMap}
     * @param tarId           key of the right dataset in {@code datasetMap}
     * @param leftHeaderName  join column of the left dataset
     * @param rightHeaderName join column of the right dataset
     * @param leftAttrName    left column exposed as {@code src}
     * @param rightAttrName   right column exposed as {@code tar}
     * @param rowNumIncrease  offset added to the generated row numbers
     * @return dataset with the columns {@code src}, {@code tar} and {@code rowNum}
     */
    public Dataset<Row> getJoinRelation(
            String srcId, String tarId,
            String leftHeaderName, String rightHeaderName,
            String leftAttrName, String rightAttrName,
            Long rowNumIncrease) {
        Dataset<Row> left = datasetMap.get(srcId);
        Dataset<Row> right = datasetMap.get(tarId);

        // Reduce each side to its distinct (join key, attribute) pairs under a fixed alias.
        Dataset<Row> srcPairs = left
                .select(new Column(leftHeaderName).as("joinkey"), new Column(leftAttrName).as("attr"))
                .distinct()
                .as("src");
        Dataset<Row> tarPairs = right
                .select(new Column(rightHeaderName).as("joinkey"), new Column(rightAttrName).as("attr"))
                .distinct()
                .as("tar");

        Column sameJoinKey = new Column("src.joinkey").equalTo(new Column("tar.joinkey"));
        String rowNumExpr = String.format("ROW_NUMBER() OVER (ORDER BY 1) + %s", rowNumIncrease);
        return srcPairs.join(tarPairs, sameJoinKey)
                .select(new Column("src.attr").as("src"), new Column("tar.attr").as("tar"))
                .distinct()
                .withColumn("rowNum", functions.expr(rowNumExpr));
    }

    /**
     * Local test driver: loads one graph definition from the database and builds it with a
     * two-thread local Spark session.
     *
     * @param args optional; {@code args[0]} is the graph id to build (defaults to 2817)
     * @throws Exception when the connection, the graph definition or the Spark job fails
     */
    public static void main(String[] args) throws Exception {
        long graphId = (args != null && args.length > 0) ? Long.parseLong(args[0]) : 2817L;
        SparkConf conf = new SparkConf();
        conf.set("spark.driver.maxResultSize", "4g");
        SparkSession sparkSession = SparkSession.builder()
                .master("local[2]")
                .config(conf)
                .appName("SparkGraphTest").getOrCreate();
        try {
            String dataJson;
            // The connection is only needed to fetch the definition; release it right away.
            try (Connection conn = UtilTool.getConn()) {
                dataJson = GraphBuild.getTaskInput(conn, graphId);
            }
            GraphBuild build = new GraphBuild(sparkSession, graphId, dataJson);
            build.loadGraphData();
        } finally {
            // Stop the session even when building fails.
            sparkSession.stop();
        }
    }
}
